:orphan:

Sklearn Basics 3: Train a Classifier on a Snowflake Multi-Table Dataset
=======================================================================

In this notebook, we will learn how to train a classifier on a more
complex multi-table dataset in which a secondary table is itself the
parent of another table (i.e. a snowflake schema). We highly recommend
going through the *Sklearn Basics 1* and *Sklearn Basics 2* lessons
first if you are not familiar with Khiops’ sklearn estimators.

We start by importing the sklearn estimator ``KhiopsClassifier``:

.. code:: ipython3

    import os
    import pandas as pd
    from khiops import core as kh
    from khiops.sklearn import KhiopsClassifier, train_test_split_dataset
    from sklearn import metrics

    # If there are any issues, you may print the Khiops status with the following command
    # kh.get_runner().print_status()

Training a Multi-Table Classifier
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We’ll train a multi-table classifier on an extension of the
``AccidentsSummary`` dataset that we used in the previous notebook,
Sklearn Basics 2. This dataset, ``Accidents``, contains the additional
table ``Users`` and is organized in the following relational snowflake
schema.

::

   Accidents
   |
   | -- 1:n -- Vehicles
   |             |
   |             |-- 1:n -- Users
   |
   | -- 1:1 -- Places

Note that the target variable is ``Gravity``.

To train the ``KhiopsClassifier`` for this setup, we must specify a
multi-table dataset. Let’s first check the content of the tables:

- The main table ``Accidents``.
- The first secondary table ``Vehicles``, which has a ``1:n``
  relationship with ``Accidents``.
- The second secondary table ``Places``, which has a ``1:1``
  relationship with ``Accidents``.
- The tertiary table ``Users``, which has a ``1:n`` relationship with
  ``Vehicles``.

.. code:: ipython3

    accidents_dataset_dir = os.path.join(kh.get_samples_dir(), "Accidents")

    accidents_file = os.path.join(accidents_dataset_dir, "Accidents.txt")
    accidents_df = pd.read_csv(accidents_file, sep="\t")
    print("Accident dataframe (first 10 rows):")
    display(accidents_df.head(10))
    print()

    vehicles_file = os.path.join(accidents_dataset_dir, "Vehicles.txt")
    vehicles_df = pd.read_csv(vehicles_file, sep="\t")
    print("Vehicle dataframe (first 10 rows):")
    display(vehicles_df.head(10))

    users_file = os.path.join(accidents_dataset_dir, "Users.txt")
    users_df = pd.read_csv(users_file, sep="\t")
    print("User dataframe (first 10 rows):")
    display(users_df.head(10))
    print()

    places_file = os.path.join(accidents_dataset_dir, "Places.txt")
    places_df = pd.read_csv(places_file, sep="\t", low_memory=False)
    print("Places dataframe (first 10 rows):")
    display(places_df.head(10))

.. parsed-literal::

    Accident dataframe (first 10 rows):

.. parsed-literal::

         AccidentId    Gravity        Date      Hour               Light  \
    0  201800000001  NonLethal  2018-01-24  15:05:00            Daylight
    1  201800000002  NonLethal  2018-02-12  10:15:00            Daylight
    2  201800000003  NonLethal  2018-03-04  11:35:00            Daylight
    3  201800000004  NonLethal  2018-05-05  17:35:00            Daylight
    4  201800000005  NonLethal  2018-06-26  16:05:00            Daylight
    5  201800000006  NonLethal  2018-09-23  06:30:00      TwilightOrDawn
    6  201800000007  NonLethal  2018-09-26  00:40:00  NightStreelightsOn
    7  201800000008     Lethal  2018-11-30  17:15:00  NightStreelightsOn
    8  201800000009  NonLethal  2018-02-18  15:57:00            Daylight
    9  201800000010  NonLethal  2018-03-19  15:30:00            Daylight

       Department  Commune  InAgglomeration  IntersectionType    Weather  \
    0         590        5               No            Y-type     Normal
    1         590       11              Yes            Square   VeryGood
    2         590      477              Yes            T-type     Normal
    3         590       52              Yes    NoIntersection   VeryGood
    4         590      477              Yes    NoIntersection     Normal
    5         590       52              Yes    NoIntersection  LightRain
    6         590      133              Yes    NoIntersection     Normal
    7         590       11              Yes    NoIntersection     Normal
    8         590      550               No    NoIntersection     Normal
    9         590       51              Yes            X-type     Normal

                          CollisionType             PostalAddress  GPSCode  \
    0  2Vehicles-BehindVehicles-Frontal    route des Ansereuilles        M
    1                       NoCollision  Place du général de Gaul        M
    2                       NoCollision             Rue nationale        M
    3                    2Vehicles-Side       30 rue Jules Guesde        M
    4                    2Vehicles-Side        72 rue Victor Hugo        M
    5                             Other                       D39        M
    6                             Other        4 route de camphin        M
    7                             Other         rue saint exupéry        M
    8                             Other          rue de l'égalité        M
    9  2Vehicles-BehindVehicles-Frontal   face au 59 rue de Lille        M

       Latitude  Longitude
    0  50.55737    2.55737
    1  50.52936    2.52936
    2  50.51243    2.51243
    3  50.51974    2.51974
    4  50.51607    2.51607
    5  50.52132    2.52132
    6  50.52211    2.52211
    7  50.53146    2.53146
    8  50.53707    2.53707
    9  50.53639    2.53639

.. parsed-literal::

    Vehicle dataframe (first 10 rows):

.. parsed-literal::

         AccidentId  VehicleId  Direction          Category  PassengerNumber  \
    0  201800000001        A01    Unknown         Car<=3.5T                0
    1  201800000001        B01    Unknown         Car<=3.5T                0
    2  201800000002        A01    Unknown         Car<=3.5T                0
    3  201800000003        A01    Unknown  Motorbike>125cm3                0
    4  201800000003        B01    Unknown         Car<=3.5T                0
    5  201800000003        C01    Unknown         Car<=3.5T                0
    6  201800000004        A01    Unknown         Car<=3.5T                0
    7  201800000004        B01    Unknown           Bicycle                0
    8  201800000005        A01    Unknown             Moped                0
    9  201800000005        B01    Unknown         Car<=3.5T                0

           FixedObstacle  MobileObstacle  ImpactPoint           Maneuver
    0                NaN         Vehicle   RightFront         TurnToLeft
    1                NaN         Vehicle    LeftFront  NoDirectionChange
    2                NaN      Pedestrian          NaN  NoDirectionChange
    3  StationaryVehicle         Vehicle        Front  NoDirectionChange
    4                NaN         Vehicle     LeftSide         TurnToLeft
    5                NaN             NaN    RightSide             Parked
    6                NaN           Other   RightFront          Avoidance
    7                NaN         Vehicle     LeftSide                NaN
    8                NaN         Vehicle   RightFront           PassLeft
    9                NaN         Vehicle    LeftFront               Park

.. parsed-literal::

    User dataframe (first 10 rows):

.. parsed-literal::

         AccidentId  VehicleId  Seat    Category  Gender  TripReason    SafetyDevice  \
    0  201800000001        A01   1.0      Driver    Male     Leisure        SeatBelt
    1  201800000001        B01   1.0      Driver    Male         NaN        SeatBelt
    2  201800000002        A01   1.0      Driver    Male         NaN        SeatBelt
    3  201800000002        A01   NaN  Pedestrian    Male         NaN          Helmet
    4  201800000003        A01   1.0      Driver    Male     Leisure          Helmet
    5  201800000003        C01   1.0      Driver    Male         NaN  ChildrenDevice
    6  201800000004        A01   1.0      Driver    Male     Leisure        SeatBelt
    7  201800000004        B01   1.0      Driver    Male     Leisure          Helmet
    8  201800000005        A01   1.0      Driver    Male     Leisure          Helmet
    9  201800000005        B01   1.0      Driver    Male     Leisure        SeatBelt

       SafetyDeviceUsed            PedestrianLocation  PedestrianAction  \
    0               Yes                           NaN               NaN
    1               Yes                           NaN               NaN
    2               Yes                           NaN               NaN
    3               NaN  OnLane<=OnSidewalk0mCrossing          Crossing
    4               Yes                           NaN               NaN
    5               NaN                           NaN               NaN
    6               Yes                           NaN               NaN
    7               NaN                           NaN               NaN
    8               Yes                           NaN               NaN
    9               Yes                           NaN               NaN

       PedestrianCompany  BirthYear
    0            Unknown     1960.0
    1            Unknown     1928.0
    2            Unknown     1947.0
    3              Alone     1959.0
    4            Unknown     1987.0
    5            Unknown     1977.0
    6            Unknown     1982.0
    7            Unknown     2013.0
    8            Unknown     2001.0
    9            Unknown     1946.0

.. parsed-literal::

    Places dataframe (first 10 rows):

.. parsed-literal::

         AccidentId       RoadType  RoadNumber  RoadSecNumber  RoadLetter  \
    0  201800000001  Departamental          41            NaN           C
    1  201800000002       Communal          41            NaN           D
    2  201800000003  Departamental          39            NaN           D
    3  201800000004  Departamental          39            NaN         NaN
    4  201800000005       Communal         NaN            NaN         NaN
    5  201800000006  Departamental          39            NaN           D
    6  201800000007  Departamental          41            NaN           D
    7  201800000008       Communal           -            NaN         NaN
    8  201800000009  Departamental         141            NaN           D
    9  201800000010  Departamental         641            NaN         NaN

       Circulation  LaneNumber  SpecialLane   Slope  RoadMarkerId  \
    0       TwoWay         2.0            0    Flat           NaN
    1       TwoWay         2.0            0    Flat           NaN
    2       TwoWay         2.0            0    Flat           NaN
    3       TwoWay         2.0            0    Flat           NaN
    4       OneWay         1.0            0    Flat           NaN
    5      Unknown         2.0            0  Uphill           NaN
    6       TwoWay         2.0            0    Flat          16.0
    7       TwoWay         2.0            0    Flat           NaN
    8       TwoWay         2.0            0    Flat           NaN
    9       TwoWay         2.0         Bike    Flat           1.0

       RoadMarkerDistance      Layout  StripWidth  LaneWidth  SurfaceCondition  \
    0                 NaN  RightCurve         NaN        NaN            Normal
    1                 NaN   LeftCurve         NaN        NaN            Normal
    2                 NaN    Straight         NaN        NaN            Normal
    3                 NaN    Straight         NaN        NaN            Normal
    4                 NaN    Straight         NaN        NaN            Normal
    5                 NaN   LeftCurve         NaN        NaN               Wet
    6               500.0    Straight         NaN        NaN            Normal
    7                 NaN    Straight         NaN        NaN            Normal
    8                 NaN    Straight         NaN        NaN            Normal
    9               670.0    Straight         NaN        NaN            Normal

       Infrastructure  Localization  SchoolNear
    0         Unknown          Lane         0.0
    1         Unknown          Lane         0.0
    2         Unknown          Lane         0.0
    3         Unknown          Lane         0.0
    4         Unknown          Lane         0.0
    5         Unknown      Shoulder         0.0
    6         Unknown      Shoulder         0.0
    7         Unknown          Lane         0.0
    8         Unknown      Shoulder         0.0
    9         Unknown          Lane         0.0

Create the multi-table dataset specification
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Note that the main table ``Accidents`` and the secondary table
``Places`` have a single key, ``AccidentId``. Tables ``Vehicles`` (the
other secondary table) and ``Users`` (the tertiary table) have two keys:
``AccidentId`` and ``VehicleId``.

To describe the relations between tables, we must add the ``relations``
field to the dataset spec. This field contains a list of tuples
describing the relations between tables. The first two values (``str``)
of each tuple are the names of the parent and the child table involved
in the relation. A third value (``bool``) can optionally be set to
``True`` to indicate that the relation is ``1:1``. For example, if the
tuple ``(table1, table2, True)`` is contained in this field, it means
that:

- ``table1`` and ``table2`` are in a ``1:1`` relationship
- The key of ``table1`` is contained in that of ``table2`` (i.e. the
  keys are hierarchical)

If the ``relations`` field is not present, then Khiops Python assumes
that the tables are in a *star* schema with ``main_table`` as the
central table.

.. code:: ipython3

    X_accidents = {
        "main_table": "Accidents",
        "tables": {
            # The target column "Gravity" is dropped from the main table
            "Accidents": (accidents_df.drop("Gravity", axis=1), "AccidentId"),
            "Vehicles": (vehicles_df, ["AccidentId", "VehicleId"]),
            "Users": (users_df, ["AccidentId", "VehicleId"]),
            "Places": (places_df, "AccidentId"),
        },
        "relations": [
            ("Accidents", "Vehicles"),
            ("Vehicles", "Users"),
            ("Accidents", "Places", True),
        ],
    }
    y_accidents = accidents_df["Gravity"]

Split the dataset into train and test
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We use the helper function ``train_test_split_dataset`` with the ``X``
dataset spec to obtain one spec for train and another for test.

.. code:: ipython3

    (
        X_accidents_train,
        X_accidents_test,
        y_accidents_train,
        y_accidents_test,
    ) = train_test_split_dataset(X_accidents, y_accidents, test_size=0.3)

Train a classifier with this dataset
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- You may choose the number of features ``n_features`` to be created by
  the Khiops AutoML engine
- Set the number of trees to zero (``n_trees=0``)

.. code:: ipython3

    khc_accidents = KhiopsClassifier(n_trees=0, n_features=1000)
    khc_accidents.fit(X_accidents_train, y_accidents_train)

.. parsed-literal::

    KhiopsClassifier(n_features=1000, n_trees=0)

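
The hierarchical-key condition stated for the ``relations`` field can be checked by hand before fitting. The following is a minimal sketch in plain Python, not part of Khiops: the helpers ``as_key_list`` and ``check_relations`` and the ``table_keys`` dictionary are hypothetical names introduced only for illustration, and the dataframes are omitted so that only the keys of the spec are inspected.

```python
def as_key_list(key):
    """Normalize a key given as a single column name or a list of names."""
    return [key] if isinstance(key, str) else list(key)


def check_relations(table_keys, relations):
    """Return a list of problems found in a relations list (empty if none).

    table_keys maps each table name to its key column(s).
    relations is a list of (parent, child) or (parent, child, is_one_to_one).
    This is an illustrative helper, not a Khiops function.
    """
    problems = []
    for relation in relations:
        parent, child = relation[0], relation[1]
        missing = [name for name in (parent, child) if name not in table_keys]
        if missing:
            problems.append(f"unknown table(s) {missing} in {relation}")
            continue
        parent_key = as_key_list(table_keys[parent])
        child_key = as_key_list(table_keys[child])
        # Hierarchical keys: every parent key column must also be in the child key
        if not set(parent_key) <= set(child_key):
            problems.append(f"key of {parent} is not contained in key of {child}")
    return problems


# Keys copied from the dataset spec of this tutorial (dataframes omitted)
table_keys = {
    "Accidents": "AccidentId",
    "Vehicles": ["AccidentId", "VehicleId"],
    "Users": ["AccidentId", "VehicleId"],
    "Places": "AccidentId",
}
relations = [
    ("Accidents", "Vehicles"),
    ("Vehicles", "Users"),
    ("Accidents", "Places", True),
]
problems = check_relations(table_keys, relations)
print(problems)
```

An empty list means that every relation names known tables and that each parent key is contained in the key of its child. Whatever validation Khiops itself performs on the spec may differ from this sketch.
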
Print the train accuracy and train AUC of the model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: ipython3

    accidents_train_performance = (
        khc_accidents.model_report_.train_evaluation_report.get_snb_performance()
    )
    print(f"Accidents train accuracy: {accidents_train_performance.accuracy}")
    print(f"Accidents train auc     : {accidents_train_performance.auc}")

.. parsed-literal::

    Accidents train accuracy: 0.945238
    Accidents train auc     : 0.844578

Deploy the classifier to obtain predictions and probabilities on the test data
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: ipython3

    y_accidents_test_predicted = khc_accidents.predict(X_accidents_test)
    probas_accidents_test = khc_accidents.predict_proba(X_accidents_test)

    print("Accidents test predictions (first 10 values):")
    display(y_accidents_test_predicted[:10])
    print("Accidents test prediction probabilities (first 10 values):")
    display(probas_accidents_test[:10])

.. parsed-literal::

    Accidents test predictions (first 10 values):

.. parsed-literal::

    array(['NonLethal', 'NonLethal', 'NonLethal', 'NonLethal', 'NonLethal',
           'NonLethal', 'NonLethal', 'NonLethal', 'NonLethal', 'NonLethal'],
          dtype=object)

.. parsed-literal::

    Accidents test prediction probabilities (first 10 values):

.. parsed-literal::

    array([[0.00561741, 0.99438259],
           [0.00743387, 0.99256613],
           [0.20684597, 0.79315403],
           [0.0210582 , 0.9789418 ],
           [0.01925004, 0.98074996],
           [0.03408831, 0.96591169],
           [0.11247547, 0.88752453],
           [0.07489168, 0.92510832],
           [0.00639295, 0.99360705],
           [0.42633439, 0.57366561]])

Estimate the accuracy and AUC metrics on the test data
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: ipython3

    accidents_test_accuracy = metrics.accuracy_score(
        y_accidents_test, y_accidents_test_predicted
    )
    accidents_test_auc = metrics.roc_auc_score(
        y_accidents_test, probas_accidents_test[:, 1]
    )

    print(f"Accidents test accuracy: {accidents_test_accuracy}")
    print(f"Accidents test auc     : {accidents_test_auc}")

.. parsed-literal::

    Accidents test accuracy: 0.9445630227862706
    Accidents test auc     : 0.8304566361006444
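
To make the two metrics more concrete, here is a minimal sketch that computes them by hand in plain Python. It is not how scikit-learn implements them: ``accuracy`` and ``roc_auc`` are hypothetical helpers, and the labels and probabilities are toy values invented for illustration, not taken from the Accidents dataset. In the printed probabilities, the second column carries the larger values for rows predicted ``NonLethal``, which suggests that ``probas_accidents_test[:, 1]`` is the ``NonLethal`` probability; the sketch therefore treats ``NonLethal`` as the positive class.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    matches = sum(t == p for t, p in zip(y_true, y_pred))
    return matches / len(y_true)


def roc_auc(y_true, scores, positive_label):
    """AUC as the probability that a random positive outranks a random negative.

    Ties between a positive and a negative score count as half a win.
    """
    pos = [s for t, s in zip(y_true, scores) if t == positive_label]
    neg = [s for t, s in zip(y_true, scores) if t != positive_label]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical toy data, NOT taken from the Accidents dataset
y_true = ["NonLethal", "NonLethal", "Lethal", "NonLethal", "Lethal"]
# Hypothetical probabilities of the "NonLethal" class
p_non_lethal = [0.99, 0.80, 0.60, 0.55, 0.30]
# Predict the class with the larger probability (threshold 0.5)
y_pred = ["NonLethal" if p >= 0.5 else "Lethal" for p in p_non_lethal]

toy_accuracy = accuracy(y_true, y_pred)
toy_auc = roc_auc(y_true, p_non_lethal, positive_label="NonLethal")
print(f"Toy accuracy: {toy_accuracy}")
print(f"Toy auc     : {toy_auc}")
```

Accuracy depends only on the predicted labels, while AUC depends only on how the scores rank the examples. This is why a model can reach a high accuracy on an imbalanced target such as ``Gravity`` and still have a noticeably lower AUC.
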